
[fix] Revert change that removed the option to run OffloadModel without activation checkpointing. #608

Merged: 13 commits merged into master on Apr 15, 2021

Conversation

@anj-s anj-s (Contributor) commented Apr 14, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

  • The checkpoint_activation option was removed incorrectly. Reverted the SyncShard change, since it makes the code more readable in spite of adding another code path. I am working on a code refactor but wanted to get this change checked in in the meantime (a usage sketch of the restored flag follows after this list).
  • Added tests, since the existing ones did not detect the missing code path.
  • Modified the benchmarks to account for the checkpoint_activation flag. I will add a benchmark in an upcoming PR once I have a few more runs tabulated.
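
For reference, a minimal usage sketch of the restored option. It assumes the OffloadModel constructor from fairscale.experimental.nn.offload; the device choices and num_slices value are illustrative, and only the checkpoint_activation flag is the point.

# Minimal sketch (assumptions noted above): the same wrapped model, with and
# without activation checkpointing.
import torch
from fairscale.experimental.nn.offload import OffloadModel

model = torch.nn.Sequential(
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
)

# Path restored by this PR: offloading without activation checkpointing.
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=False,
)

# Existing path: the forward pass is wrapped in activation checkpointing.
offload_model_ckpt = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,
)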

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 14, 2021
@anj-s anj-s requested a review from blefaudeux April 14, 2021 20:32
@@ -292,6 +292,75 @@ def backward(ctx, *grad_outputs): # type: ignore
return (None, None) + grads


class ShardSyncLayer(torch.autograd.Function):
Contributor:

at least this part I'm a bit familiar with :)

@@ -386,4 +455,23 @@ def forward(self, *inputs: Any, **_: Any) -> Any:

# We need the second param to be a dummy input to enable the
# backward pass to be triggered for integer inputs.
return ActivationCheckpointing.apply(*inputs, torch.tensor([], requires_grad=True), self)
if self._checkpoint_activation:
Contributor:

oh, I must have reviewed the offending PR and missed that, sorry about that

Contributor (author):

No worries! I realized that the tests weren't really catching this, so I'm glad I caught it.
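
For readers following along, the hunk above restores a two-way branch in forward: the checkpointed path wraps the whole pass in ActivationCheckpointing, while the non-checkpointed path walks the model shards with ShardSyncLayer marking the boundaries. The sketch below is illustrative only; the shard loop and the ShardSyncLayer.apply signature are assumptions, not the fairscale implementation.

# Illustrative sketch, not the fairscale source. self.model_slices and the
# ShardSyncLayer.apply signature are assumed for the example.
def forward(self, *inputs: Any, **_: Any) -> Any:
    if self._checkpoint_activation:
        # The empty dummy tensor ensures the backward pass is triggered even
        # when all real inputs are integer tensors.
        return ActivationCheckpointing.apply(
            *inputs, torch.tensor([], requires_grad=True), self
        )

    # Non-checkpointed path: run shard by shard; ShardSyncLayer is an
    # autograd.Function that marks shard boundaries so the backward pass can
    # move the right shard back onto the compute device.
    activations = inputs
    for index, shard in enumerate(self.model_slices):
        activations = ShardSyncLayer.apply(activations, index, self.model_slices, self)
        activations = (shard(*activations),)
    return activations[0]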

offload_model.train()
pred = offload_model(input)
loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, labels)
Contributor:

Checking elsewhere for some form of parity? Wondering just in case.
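
As a side note on the parity question above, one common shape for such a check (not from this PR; the tolerance and function name are illustrative) is to compare the offloaded model's loss against the plain module it was built from:

# Illustrative parity check. Assumes both models were constructed from the
# same weights and that inputs/labels already live on devices both can accept
# (e.g. a CPU-only run).
import torch

def check_parity(reference_model, offload_model, inputs, labels, atol=1e-5):
    loss_fn = torch.nn.MSELoss(reduction="sum")
    with torch.no_grad():
        reference_loss = loss_fn(reference_model(inputs), labels)
        offload_loss = loss_fn(offload_model(inputs), labels)
    # Compare on CPU so it does not matter which device each model ran on.
    assert torch.allclose(reference_loss.cpu(), offload_loss.cpu(), atol=atol), (
        f"loss mismatch: {reference_loss.item()} vs {offload_loss.item()}"
    )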

@blefaudeux blefaudeux (Contributor) left a comment:

LGTM, especially since it's a revert; otherwise I'm missing some bits of the big picture, checking 1:1.

@anj-s anj-s merged commit a77c56f into master Apr 15, 2021
@anj-s anj-s deleted the revert-sync-shard branch April 15, 2021 02:50
@min-xu-ai min-xu-ai mentioned this pull request Apr 19, 2021